#!/usr/bin/env python3
# Bin-by-bin runner for prestack_kids.py with progress, resume, and parallelism.
import os, json, time, subprocess, queue, threading, sys
from pathlib import Path

# ---- CONFIG (leave as-is unless you need to change paths) ----
# Catalogue file handed to scripts/prestack_kids.py through --kids.
KIDS = "data/KiDS_DR4.1_ugriZYJHKs_SOM_gold_WL_cat.fits"
# Comma-separated bin edges, forwarded verbatim as --rg-bins.
RG_EDGES = "1.5,3,5,8,12"
# Comma-separated bin edges, forwarded verbatim as --mstar-bins.
MS_EDGES = "10.2,10.5,10.8,11.1"
# Comma-separated bin edges, forwarded verbatim as --b-bins-arcsec.
B_EDGES  = "10,15,22,32,46,66,95,137,198,285,410,592,855,1236,1787,2583"
# Interpreter used to launch each per-bin subprocess: the PYTHON environment
# variable when set, otherwise the interpreter running this script.
PY = os.environ.get("PYTHON", sys.executable)
# --------------------------------------------------------------

def _norm_bin(s: str) -> str:
    """Turn a bin label into a filename-friendly token.

    Square brackets and spaces are dropped and every comma becomes a hyphen,
    e.g. ``"[1.5, 3]"`` -> ``"1.5-3"``. All other characters pass through.
    """
    # One pass: map ',' -> '-', delete '[', ']' and ' '.
    table = str.maketrans(",", "-", "[] ")
    return s.translate(table)

def _read_csv(path):
    """Load the CSV at *path* into a pandas DataFrame with default parsing options."""
    # pandas is imported lazily, so it is only loaded when this helper is called.
    from pandas import read_csv

    frame = read_csv(path)
    return frame

def _write_csv(df, path):
    """Write *df* to *path* as CSV, leaving the row index out of the file."""
    df.to_csv(path_or_buf=path, index=False)

def _split_bins(in_csv: str, out_dir: Path) -> list:
    """Split a lens CSV into one CSV per non-empty (RG bin, Mstar bin) pair.

    The RG label column is ``RG_bin`` (fallback ``R_G_bin``) and the Mstar
    label column is ``Mstar_bin`` (fallback ``M_bin``). Labels are compared as
    strings. Each non-empty combination is written to
    ``out_dir/lenses_rg<RG>_ms<MS>.csv`` and an ``index.json`` describing all
    written files is stored next to them.

    Args:
        in_csv: path of the input CSV.
        out_dir: directory for the per-bin files; created if missing.

    Returns:
        A list of dicts with keys ``rg_bin``, ``ms_bin``, ``rows`` and ``path``,
        ordered by sorted RG label, then sorted Mstar label.

    Raises:
        SystemExit: if either label column is absent from the input.
    """
    import pandas as pd

    df = pd.read_csv(in_csv)
    rg = next((c for c in ("RG_bin", "R_G_bin") if c in df.columns), None)
    ms = next((c for c in ("Mstar_bin", "M_bin") if c in df.columns), None)
    if rg is None or ms is None:
        raise SystemExit(f"Missing RG_bin/Mstar_bin in {in_csv}")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Stringify each label column once instead of once per (r, m) pair; the
    # conversion and the sorted Mstar labels do not change inside the loops.
    rg_labels = df[rg].astype(str)
    ms_labels = df[ms].astype(str)
    ms_values = sorted(ms_labels.unique())

    idx = []
    for r in sorted(rg_labels.unique()):
        in_rg = rg_labels == r  # reused for every Mstar bin of this RG bin
        for m in ms_values:
            sub = df[in_rg & (ms_labels == m)]
            if sub.empty:
                continue
            fn = out_dir / f"lenses_rg{_norm_bin(r)}_ms{_norm_bin(m)}.csv"
            sub.to_csv(fn, index=False)
            idx.append({"rg_bin": r, "ms_bin": m, "rows": int(len(sub)), "path": str(fn)})
            print(f"[bin] {r} × {m}: {len(sub):,} -> {fn}")

    # The index is written last, so its presence means the split completed.
    (out_dir / "index.json").write_text(json.dumps(idx, indent=2))
    print(f"[OK] wrote {len(idx)} bin files -> {out_dir/'index.json'}")
    return idx

def _concat(parts: list, metas: list, target: str):
    """Merge the per-bin stack and meta CSVs that exist into combined files.

    Paths in *parts* / *metas* that do not exist are skipped. The combined
    stack file goes to ``data/prestacked_stacks_<sfx>.csv`` and the combined
    meta file to ``data/prestacked_meta_<sfx>.csv``, where ``<sfx>`` is
    ``lens`` when *target* is ``"lenses"`` and ``rand`` otherwise. A combined
    file is only written when at least one of its inputs was found.
    """
    import pandas as pd

    suffix = "lens" if target == "lenses" else "rand"
    out_all = f"data/prestacked_stacks_{suffix}.csv"
    meta_all = f"data/prestacked_meta_{suffix}.csv"

    stack_frames = list(map(pd.read_csv, filter(os.path.exists, parts)))
    meta_frames = list(map(pd.read_csv, filter(os.path.exists, metas)))

    for frames, destination in ((stack_frames, out_all), (meta_frames, meta_all)):
        if frames:
            merged = pd.concat(frames, ignore_index=True)
            merged.to_csv(destination, index=False)

    n_rows = sum(len(frame) for frame in stack_frames)
    print(f"[OK] concatenated -> {out_all} ({n_rows} rows), {meta_all}")

def run(target="lenses", workers=None, use_m_corr=True, resume=True):
    """Prestack every (RG, Mstar) bin of *target* in parallel subprocesses, with resume.

    Workflow:
      1. Load ``work/bins/<target>/index.json`` if present, otherwise create the
         per-bin files with ``_split_bins``.
      2. With ``resume`` on, skip bins listed under "completed" in
         ``data/prestack_progress_<target>.json`` whose part file still exists.
      3. Run ``scripts/prestack_kids.py`` once per remaining bin from a pool of
         worker threads; each subprocess logs stdout/stderr to
         ``logs/prestack_<target>_<stem>.log``.
      4. Rewrite the progress JSON roughly every 5 s and once more at the end.
      5. If every bin succeeded, merge the parts with ``_concat``.

    Args:
        target: ``"lenses"`` (input ``data/lenses.csv``) or ``"randoms"``
            (input ``data/lenses_random.csv``).
        workers: number of concurrent subprocesses. ``None`` or ``0`` selects a
            third of the detected CPUs (2 when that rounds to zero or the CPU
            count is unknown); the value is clamped to at least 1.
        use_m_corr: append ``--use-m-corr`` to every subprocess command line.
        resume: honour the progress file left by a previous run.

    Raises:
        SystemExit: if *target* is invalid, or with code 1 if any bin failed,
            could not be launched, or was never run.
    """
    if target not in ("lenses", "randoms"):
        raise SystemExit("target must be 'lenses' or 'randoms'")
    in_csv = "data/lenses.csv" if target == "lenses" else "data/lenses_random.csv"
    base = Path(f"work/bins/{target}")
    base.mkdir(parents=True, exist_ok=True)
    logs = Path("logs")
    logs.mkdir(exist_ok=True)
    prog = Path(f"data/prestack_progress_{target}.json")

    # Split lenses (or reuse an existing split)
    idx_path = base / "index.json"
    if idx_path.exists():
        bins = json.loads(idx_path.read_text())
    else:
        bins = _split_bins(in_csv, base)

    # Resume: keys an earlier run recorded as completed.
    recorded = set()
    have_record = False
    if resume and prog.exists():
        try:
            recorded = set(json.loads(prog.read_text()).get("completed", []))
            have_record = True
        except (OSError, ValueError, AttributeError, TypeError) as exc:
            # A truncated or malformed progress file means every bin is re-run;
            # say so instead of silently discarding the record.
            print(f"[resume] ignoring unreadable progress file {prog}: {exc}")

    # Jobs
    jobs, parts, metas = [], [], []
    # Only bins that are really skipped count as completed; a recorded bin whose
    # part file vanished is re-run and must not be counted twice.
    completed = set()
    for b in bins:
        key = f"{b['rg_bin']}|{b['ms_bin']}"
        stem = Path(b["path"]).stem
        part = base / f"prestack_{target}_{stem}.csv"
        meta = base / f"meta_{target}_{stem}.csv"
        log = logs / f"prestack_{target}_{stem}.log"
        parts.append(str(part))
        metas.append(str(meta))
        if key in recorded and part.exists():
            completed.add(key)
            continue
        cmd = [
            PY, "scripts/prestack_kids.py",
            "--kids", KIDS, "--lenses", b["path"],
            "--out", str(part), "--out-meta", str(meta),
            "--rg-bins", RG_EDGES, "--mstar-bins", MS_EDGES,
            "--b-bins-arcsec", B_EDGES, "--min-zsep", "0.1",
        ]
        if use_m_corr:
            cmd.append("--use-m-corr")
        jobs.append({"key": key, "cmd": cmd, "log": str(log), "rows": b["rows"]})
    if have_record:
        print(f"[resume] skipping {len(completed)} already-completed bins")

    total = len(bins)
    # os.cpu_count() may return None; treat that like "too few CPUs" -> 2 workers.
    n_workers = max(1, workers or ((os.cpu_count() or 0) // 3 or 2))
    print(f"[run] target={target} bins={total} plan={len(jobs)} workers={n_workers}")
    if not jobs:
        print("[skip] nothing to do; concatenating existing parts …")
        _concat(parts, metas, target)
        return

    q = queue.Queue()
    for j in jobs:
        q.put(j)
    results = []

    def worker():
        """Pull jobs off the queue until it is empty, recording rc/dt on each job."""
        while True:
            try:
                job = q.get_nowait()
            except queue.Empty:
                return
            t0 = time.time()
            try:
                with open(job["log"], "w") as lf:
                    rc = subprocess.run(job["cmd"], stdout=lf, stderr=subprocess.STDOUT).returncode
            except Exception as exc:
                # Thread boundary: a launch failure (missing interpreter,
                # unwritable log, ...) must be recorded as a failed job. If it
                # escaped, the thread would die and the bin would be lost.
                rc = -1
                job["error"] = repr(exc)
            job["rc"] = rc
            job["dt"] = time.time() - t0
            results.append(job)
            q.task_done()

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(n_workers)]
    for t in threads:
        t.start()

    started = time.time()
    last_flush = 0
    while any(t.is_alive() for t in threads):
        time.sleep(0.5)
        if time.time() - last_flush > 5:
            # Snapshot once so all derived numbers describe the same moment.
            succeeded = [r for r in list(results) if r["rc"] == 0]
            done_keys = completed | {r["key"] for r in succeeded}
            dts = [r["dt"] for r in succeeded if r["dt"] > 0]
            remaining = total - len(done_keys)
            eta = None
            if dts and remaining > 0:
                # Jobs run n_workers at a time, so wall-clock time left is the
                # mean job time times remaining jobs, divided by the parallelism.
                eta = (sum(dts) / len(dts)) * remaining / min(n_workers, remaining)
            state = {
                "target": target, "total": total, "done": len(done_keys),
                "completed": sorted(done_keys),
                "elapsed_sec": round(time.time() - started, 1),
                "eta_sec": round(eta, 1) if eta is not None else None,
            }
            prog.write_text(json.dumps(state, indent=2))
            last_flush = time.time()

    # finalize: account over *jobs* (not results) so a job that never ran
    # is reported as an error instead of being silently dropped.
    errs = [j for j in jobs if j.get("rc", 1) != 0]
    completed.update(j["key"] for j in jobs if j.get("rc", 1) == 0)
    error_entries = []
    for e in errs:
        entry = {"key": e["key"], "log": e["log"]}
        if "error" in e:
            entry["error"] = e["error"]
        elif "rc" not in e:
            entry["error"] = "job was never started"
        error_entries.append(entry)
    prog.write_text(json.dumps({
        "target": target, "total": total, "done": len(completed),
        "completed": sorted(completed),
        "elapsed_sec": round(time.time() - started, 1),
        "errors": error_entries,
    }, indent=2))
    if errs:
        for entry in error_entries:
            if "error" in entry:
                print(f"[ERR] {entry['key']}: {entry['error']}")
        print(f"[ERR] {len(errs)} bin(s) failed; see logs/*.log")
        raise SystemExit(1)

    _concat(parts, metas, target)
    print("[OK] orchestrator finished.")

if __name__ == "__main__":
    # Minimal CLI so you don't need another script:
    #   python scripts/t3_prestack_orchestrator.py lenses 6
    #   python scripts/t3_prestack_orchestrator.py randoms 6
    _USAGE = "usage: python scripts/t3_prestack_orchestrator.py <lenses|randoms> [workers]"
    if len(sys.argv) < 2:
        raise SystemExit(_USAGE)
    tgt = sys.argv[1].lower()
    try:
        w = int(sys.argv[2]) if len(sys.argv) > 2 else None
    except ValueError:
        # A non-numeric worker count is a usage error, not a crash.
        raise SystemExit(f"workers must be an integer, got {sys.argv[2]!r}\n{_USAGE}") from None
    run(target=tgt, workers=w, use_m_corr=True, resume=True)
